定義模型後,可以開始
model = CNN()
epochs = 100
batch_size = 300
lr = 1e-3
建立模組、定義訓練次數、批次大小和學習率
xt = train_data.data[batch_CNN].detach()
yt = train_data.train_labels[batch_CNN].detach()
pred = model(xt)
pred_labels = torch.argmax(pred,dim=1)
產生批次後,y軸取得標籤資訊,在找出預測的最大值作為結果
acc_ = 100.0 * (pred_labels == yt).sum() / batch_size
print('Current training accuracy: ', acc_.item())
計算準確率並預測
plt.figure(figsize=(10,7))
plt.xlabel("Training Epochs", fontsize=12)
plt.ylabel("Training accuracy", fontsize=12)
plt.plot(acc_CNN)
可以畫出每次epoch的準確度,從張圖我們可以知道訓練週期越到後面,模型的準確率也越高。
for i in range (10):
x = test_data.data[test_id[i]]
plt.imshow(x)
print('\n預測數字是:', pred_ind[i])
輸出測試資料,可以看到每張圖型和預測結果